package me.midday.matrix;

/**
 * A dense matrix of {@code double} values stored as {@code data[row][col]}.
 *
 * <p>The arithmetic operations ({@link #add(double)}, {@link #add(Matrix)}, {@link #sub},
 * {@link #multi}, {@link #dot} and the {@code sum*} family) allocate and return a new
 * matrix and do not modify their operands. Only {@link #setValue} mutates an instance.
 *
 * <p>Created by wuzheng on 17-11-9.
 */
public class Matrix {
    /** Dimensions of this matrix. */
    private final Shape shape;

    /** Element storage, indexed as {@code data[row][col]}. */
    private final double[][] data;

    /**
     * Creates a zero-filled matrix with the given dimensions.
     *
     * @param rows number of rows
     * @param cols number of columns
     */
    public Matrix(int rows, int cols) {
        shape = new Shape(rows, cols);
        data = new double[rows][cols];
    }

    /**
     * Creates a zero-filled matrix with the dimensions described by {@code shape}.
     *
     * <p>The given {@code Shape} instance is stored as-is, not copied.
     *
     * @param shape dimensions of the new matrix
     */
    public Matrix(Shape shape) {
        this.shape = shape;
        data = new double[shape.getRows()][shape.getCols()];
    }

    /**
     * Returns the dimensions of this matrix.
     *
     * @return this matrix's shape
     */
    public Shape getShape() {
        return this.shape;
    }

    /**
     * Returns a copy of one row.
     *
     * <p>A copy is returned so that callers cannot change the matrix by writing into the
     * array. Use {@link #setValue} to modify elements.
     *
     * @param row zero-based row index
     * @return a new array holding the values of that row
     * @throws ArrayIndexOutOfBoundsException if {@code row} is out of range
     */
    public double[] getRow(int row) {
        return this.data[row].clone();
    }

    /**
     * Returns a copy of one column.
     *
     * @param col zero-based column index
     * @return a new array of length {@code rows} holding the values of that column
     * @throws IndexOutOfBoundsException if {@code col} is out of range
     */
    public double[] getCol(int col) {
        if (col < 0 || col >= shape.getCols()) {
            throw new IndexOutOfBoundsException(
                    String.format("column index %d out of range for %d columns", col, shape.getCols()));
        }
        int rows = shape.getRows();
        double[] column = new double[rows];
        for (int r = 0; r < rows; r++) {
            column[r] = this.data[r][col];
        }
        return column;
    }

    /**
     * Computes the inner product of two equally long vectors.
     *
     * @throws MatrixException if the vectors have different lengths
     */
    private double vecDot(double[] a, double[] b) throws MatrixException {
        if (a.length != b.length) {
            throw new MatrixException(String.format("length need same, got %d and %d", a.length, b.length));
        }
        double result = 0.0;
        for (int i = 0; i < a.length; i++) {
            result += a[i] * b[i];
        }
        return result;
    }

    /**
     * Computes the matrix product {@code this × other}.
     *
     * @param other right-hand operand; its row count must equal this matrix's column count
     * @return a new {@code this.rows x other.cols} matrix
     * @throws MatrixException if the inner dimensions do not match
     */
    public Matrix dot(Matrix other) throws MatrixException {
        if (this.shape.getCols() != other.shape.getRows()) {
            throw new MatrixException(String.format("can't dot matrix with shape (%d, %d) and shape (%d, %d)",
                    this.shape.getRows(), this.shape.getCols(), other.shape.getRows(), other.shape.getCols()));
        }

        int rows = this.shape.getRows();
        int cols = other.shape.getCols();

        // Extract each column of the right operand once, not once per output cell.
        double[][] otherCols = new double[cols][];
        for (int c = 0; c < cols; c++) {
            otherCols[c] = other.getCol(c);
        }

        Matrix dotResult = new Matrix(rows, cols);
        for (int r = 0; r < rows; r++) {
            // Read-only access to our own row; no defensive copy needed here.
            double[] row = this.data[r];
            for (int c = 0; c < cols; c++) {
                dotResult.data[r][c] = vecDot(row, otherCols[c]);
            }
        }
        return dotResult;
    }

    /**
     * Sets a single element.
     *
     * @param value new value
     * @param row   zero-based row index
     * @param col   zero-based column index
     * @throws ArrayIndexOutOfBoundsException if either index is out of range
     */
    public void setValue(double value, int row, int col) {
        this.data[row][col] = value;
    }

    /**
     * Returns a single element.
     *
     * @param row zero-based row index
     * @param col zero-based column index
     * @return the element at {@code (row, col)}
     * @throws ArrayIndexOutOfBoundsException if either index is out of range
     */
    public double getValue(int row, int col) {
        return this.data[row][col];
    }

    /**
     * Sums the elements of each row.
     *
     * @return a new {@code rows x 1} matrix whose entry {@code (r, 0)} is the sum of row {@code r}
     */
    public Matrix sumByRow() {
        Matrix res = new Matrix(this.shape.getRows(), 1);
        for (int row = 0; row < this.shape.getRows(); row++) {
            double value = 0.0;
            for (int col = 0; col < this.shape.getCols(); col++) {
                value += this.data[row][col];
            }
            // The result has a single column, so index 0 (not 1) is the only valid column.
            res.setValue(value, row, 0);
        }
        return res;
    }

    /**
     * Sums the elements of each column.
     *
     * @return a new {@code 1 x cols} matrix whose entry {@code (0, c)} is the sum of column {@code c}
     */
    public Matrix sumByCol() {
        Matrix res = new Matrix(1, this.shape.getCols());
        for (int col = 0; col < this.shape.getCols(); col++) {
            double value = 0.0;
            for (int row = 0; row < this.shape.getRows(); row++) {
                value += this.data[row][col];
            }
            // The result has a single row, so index 0 (not 1) is the only valid row.
            res.setValue(value, 0, col);
        }
        return res;
    }

    /**
     * Sums every element of the matrix.
     *
     * @return a new {@code 1 x 1} matrix holding the total
     */
    public Matrix sumAll() {
        Matrix res = new Matrix(1, 1);
        double value = 0.0;
        for (int row = 0; row < this.shape.getRows(); row++) {
            for (int col = 0; col < this.shape.getCols(); col++) {
                value += this.data[row][col];
            }
        }
        res.setValue(value, 0, 0);
        return res;
    }

    /**
     * Sums elements along an axis.
     *
     * <p>{@code -1} sums everything ({@link #sumAll}). {@code 0} sums each row
     * ({@link #sumByRow}). {@code 1} sums each column ({@link #sumByCol}). Note that this
     * axis numbering is the reverse of NumPy's {@code sum(axis=...)}.
     *
     * @param axis -1, 0 or 1
     * @return the summed matrix
     * @throws MatrixException if {@code axis} is not -1, 0 or 1
     */
    public Matrix sum(int axis) throws MatrixException {
        switch (axis) {
            case -1:
                return sumAll();
            case 0:
                return sumByRow();
            case 1:
                return sumByCol();
            default:
                throw new MatrixException(" axis matrix is wrong");
        }
    }

    /**
     * Adds a scalar to every element.
     *
     * @param value scalar to add
     * @return a new matrix of the same shape
     */
    public Matrix add(double value) {
        Matrix res = new Matrix(this.shape);
        for (int row = 0; row < this.shape.getRows(); row++) {
            for (int col = 0; col < this.shape.getCols(); col++) {
                res.data[row][col] = this.data[row][col] + value;
            }
        }
        return res;
    }

    /**
     * Adds two matrices element-wise.
     *
     * @param mat right-hand operand; must have the same dimensions as this matrix
     * @return a new matrix of the same shape
     * @throws MatrixException if the dimensions differ
     */
    public Matrix add(Matrix mat) throws MatrixException {
        requireSameShape(mat);
        Matrix res = new Matrix(shape);
        for (int row = 0; row < this.shape.getRows(); row++) {
            for (int col = 0; col < this.shape.getCols(); col++) {
                res.data[row][col] = this.data[row][col] + mat.data[row][col];
            }
        }
        return res;
    }

    /**
     * Multiplies every element by a scalar.
     *
     * @param mu scalar factor
     * @return a new matrix of the same shape
     */
    public Matrix multi(double mu) {
        Matrix res = new Matrix(this.shape);
        for (int row = 0; row < this.shape.getRows(); row++) {
            for (int col = 0; col < this.shape.getCols(); col++) {
                res.data[row][col] = mu * this.data[row][col];
            }
        }
        return res;
    }

    /**
     * Subtracts {@code mat} from this matrix element-wise.
     *
     * @param mat right-hand operand; must have the same dimensions as this matrix
     * @return a new matrix of the same shape holding {@code this - mat}
     * @throws MatrixException if the dimensions differ
     */
    public Matrix sub(Matrix mat) throws MatrixException {
        requireSameShape(mat);
        Matrix res = new Matrix(this.shape);
        for (int row = 0; row < this.shape.getRows(); row++) {
            for (int col = 0; col < this.shape.getCols(); col++) {
                res.data[row][col] = this.data[row][col] - mat.data[row][col];
            }
        }
        return res;
    }

    /**
     * Checks that {@code mat} has the same dimensions as this matrix.
     *
     * <p>Dimensions are compared directly instead of through {@code Shape.equals}, so the
     * check works whether or not {@code Shape} overrides {@code equals}.
     *
     * @throws MatrixException if the row or column counts differ
     */
    private void requireSameShape(Matrix mat) throws MatrixException {
        if (this.shape.getRows() != mat.shape.getRows() || this.shape.getCols() != mat.shape.getCols()) {
            throw new MatrixException("shape is not equal");
        }
    }

    /**
     * Renders the matrix one row per line. Each element is surrounded by single spaces.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        for (int row = 0; row < shape.getRows(); row++) {
            for (int col = 0; col < shape.getCols(); col++) {
                builder.append(' ').append(getValue(row, col)).append(' ');
            }
            builder.append('\n');
        }
        return builder.toString();
    }
}
